{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 张量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "import tvm.testing\n",
    "from tvm import te"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[m, n, l], op.name=compute)\n",
      "[A[i, k] * B[j, k]]\n"
     ]
    }
   ],
   "source": [
    "# Size variables shared by the two inputs and the 3-D result.\n",
    "m, n, l = (te.size_var(dim) for dim in (\"m\", \"n\", \"l\"))\n",
    "A = te.placeholder((m, l), name=\"A\")\n",
    "B = te.placeholder((n, l), name=\"B\")\n",
    "\n",
    "\n",
    "def product_at(i, j, k):\n",
    "    \"\"\"Element (i, j, k) of the result: A[i, k] * B[j, k].\"\"\"\n",
    "    return A[i, k] * B[j, k]\n",
    "\n",
    "\n",
    "T = te.compute((m, n, l), product_at)\n",
    "print(T)\n",
    "print(T.op.body)\n",
    "\n",
    "# Shape of the result and kind of op behind a placeholder.\n",
    "assert tuple(T.shape) == (m, n, l)\n",
    "assert isinstance(A.op, tvm.te.PlaceholderOp)\n",
    "\n",
    "# Output 0 of T's op compares equal to T, hashes like T, and works as the same dict key.\n",
    "first_output = T.op.output(0)\n",
    "assert A == A\n",
    "assert first_output == T\n",
    "assert hash(first_output) == hash(T)\n",
    "d = {first_output: 1}\n",
    "assert d[T] == 1\n",
    "\n",
    "# Casting an indexed element changes the dtype of the resulting expression.\n",
    "element = T[0][0][0]\n",
    "assert element.astype(\"float16\").dtype == \"float16\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## rank_zero"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tensor(shape=[], op.name=compute)\n",
      "[T.reduce(T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]), source=[A[k] * s[()]], init=[], axis=[T.iter_var(k, T.Range(0, m), \"CommReduce\", \"\")], condition=T.bool(True), value_index=0)]\n"
     ]
    }
   ],
   "source": [
    "m = te.size_var(\"m\")\n",
    "A = te.placeholder((m,), name=\"A\")\n",
    "# An empty shape tuple declares a placeholder with no axes; it is read below by\n",
    "# calling it with no indices.\n",
    "scale = te.placeholder((), name=\"s\")\n",
    "k = te.reduce_axis((0, m), name=\"k\")\n",
    "\n",
    "\n",
    "def scaled_sum():\n",
    "    \"\"\"Sum of A[k] * scale over the reduce axis k.\"\"\"\n",
    "    return te.sum(A[k] * scale(), axis=k)\n",
    "\n",
    "\n",
    "T = te.compute((), scaled_sum)\n",
    "print(T)\n",
    "print(T.op.body)\n",
    "# The result has no axes either.\n",
    "assert len(T.shape) == 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## conv1d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = te.size_var(\"n\")\n",
    "# Shapes are written as explicit 1-tuples: `(n + 2)` without the trailing comma is\n",
    "# just a parenthesized scalar expression, not a tuple.\n",
    "A = te.placeholder((n + 2,), name=\"A\")\n",
    "\n",
    "\n",
    "def computeB(ii):\n",
    "    \"\"\"Sum the three neighbouring elements of A centred at index ii + 1.\"\"\"\n",
    "    i = ii + 1\n",
    "    return A[i - 1] + A[i] + A[i + 1]\n",
    "\n",
    "\n",
    "# A has n + 2 elements, so every window for ii in [0, n) stays inside A.\n",
    "B = te.compute((n,), computeB)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## slice"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = te.size_var(\"n\")\n",
    "# Constant n x n tensor used as the source of the slice.\n",
    "A = te.compute((n, n), lambda i, j: 1)\n",
    "\n",
    "\n",
    "def doubled_first_row(i):\n",
    "    \"\"\"Read A through chained indexing (A[0][i]) and add the element to itself.\"\"\"\n",
    "    return A[0][i] + A[0][i]\n",
    "\n",
    "\n",
    "B = te.compute((n,), doubled_first_row)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reduce_multi_axis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "A = te.placeholder((m, n), name=\"A\")\n",
    "# Each reduce axis spans the dimension of A it indexes: k1 walks the first axis\n",
    "# (extent m) and k2 the second (extent n). The axes get distinct names so they can\n",
    "# be told apart in printed expressions.\n",
    "k1 = te.reduce_axis((0, m), \"k1\")\n",
    "k2 = te.reduce_axis((0, n), \"k2\")\n",
    "# The reduction axes may be passed either as a tuple ...\n",
    "C = te.compute((1,), lambda _: te.sum(A[k1, k2], axis=(k1, k2)))\n",
    "# ... or as a list.\n",
    "C = te.compute((1,), lambda _: te.sum(A[k1, k2], axis=[k1, k2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## comm_reducer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "A = te.placeholder((m, n), name=\"A\")\n",
    "k = te.reduce_axis((0, n), \"k\")\n",
    "\n",
    "\n",
    "def sum_combine(x, y):\n",
    "    \"\"\"Combine step of the reducer: plain addition.\"\"\"\n",
    "    return x + y\n",
    "\n",
    "\n",
    "def sum_identity(t):\n",
    "    \"\"\"Identity element of the reducer: constant 0 of dtype t.\"\"\"\n",
    "    return tvm.tir.const(0, dtype=t)\n",
    "\n",
    "\n",
    "mysum = te.comm_reducer(sum_combine, sum_identity)\n",
    "# Reduce each row of A over k with the custom reducer.\n",
    "C = te.compute((m,), lambda i: mysum(A[i, k], axis=k))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## comm_reducer_overload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "\n",
    "\n",
    "def add_pair(x, y):\n",
    "    \"\"\"Combine step of the reducer: plain addition.\"\"\"\n",
    "    return x + y\n",
    "\n",
    "\n",
    "def zero_const(t):\n",
    "    \"\"\"Identity element of the reducer: constant 0 of dtype t.\"\"\"\n",
    "    return tvm.tir.const(0, dtype=t)\n",
    "\n",
    "\n",
    "mysum = te.comm_reducer(add_pair, zero_const)\n",
    "# Here the reducer is called on two expressions, with no reduce axis given.\n",
    "sum_res = mysum(m, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reduce"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "l = te.size_var(\"l\")\n",
    "A = te.placeholder((m, l), name=\"A\")\n",
    "B = te.placeholder((n, l), name=\"B\")\n",
    "T = te.compute((m, n, l), lambda i, j, k: A[i, k] * B[j, k])\n",
    "# rv spans [0, l), the full extent of T's last axis, so T is indexed with rv\n",
    "# itself; indexing with rv + 1 would reach index l, one past the end of that axis.\n",
    "rv = te.reduce_axis((0, A.shape[1]), \"k\")\n",
    "C = te.compute((m, n), lambda i, j: te.sum(T(i, j, rv), axis=rv))\n",
    "# Round-trip the tensor through JSON: the loaded object is again a Tensor and\n",
    "# prints the same as the original.\n",
    "C_json = tvm.ir.save_json(C)\n",
    "C_loaded = tvm.ir.load_json(C_json)\n",
    "assert isinstance(C_loaded, te.tensor.Tensor)\n",
    "assert str(C_loaded) == str(C)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reduce_multiout_with_cond"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fcombine(x, y):\n",
    "    \"\"\"Combine two (first, second) pairs by adding them element-wise.\"\"\"\n",
    "    first = x[0] + y[0]\n",
    "    second = x[1] + y[1]\n",
    "    return first, second\n",
    "\n",
    "\n",
    "def fidentity(t0, t1):\n",
    "    \"\"\"Starting pair of the reducer: constant 0 of dtype t0 and constant 1 of dtype t1.\"\"\"\n",
    "    return tvm.tir.const(0, t0), tvm.tir.const(1, t1)\n",
    "\n",
    "\n",
    "mysum = te.comm_reducer(fcombine, fidentity, name=\"mysum\")\n",
    "\n",
    "m, n = te.var(\"m\"), te.var(\"n\")\n",
    "idx = te.placeholder((m, n), name=\"idx\", dtype=\"int32\")\n",
    "val = te.placeholder((m, n), name=\"val\", dtype=\"int32\")\n",
    "k = te.reduce_axis((0, n), \"k\")\n",
    "# Predicate passed as `where=`: true for even values of k.\n",
    "cond = te.floormod(k, 2) == 0\n",
    "\n",
    "\n",
    "def reduce_row(i):\n",
    "    \"\"\"Reduce the (idx, val) pair of row i over k, subject to cond.\"\"\"\n",
    "    return mysum((idx[i, k], val[i, k]), axis=k, where=cond)\n",
    "\n",
    "\n",
    "# A two-valued reduction yields two output tensors.\n",
    "T0, T1 = te.compute((m,), reduce_row, name=\"T\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## scan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "x = te.placeholder((m, n))\n",
    "# State placeholder referenced by the update rule.\n",
    "s = te.placeholder((m, n))\n",
    "# Initial value: the first row of x.\n",
    "s_init = te.compute((1, n), lambda _, i: x[0, i])\n",
    "# Update rule: previous state row plus the current row of x.\n",
    "s_update = te.compute((m, n), lambda t, i: s[t - 1, i] + x[t, i])\n",
    "res = te.scan(s_init, s_update, s)\n",
    "assert tuple(res.shape) == (m, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## scan_multi_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "# One (input, state) placeholder pair per scanned state.\n",
    "x1, s1 = te.placeholder((m, n)), te.placeholder((m, n))\n",
    "x2, s2 = te.placeholder((m, n)), te.placeholder((m, n))\n",
    "\n",
    "# Initial values: the first row of each input.\n",
    "s1_init = te.compute((1, n), lambda _, i: x1[0, i])\n",
    "s2_init = te.compute((1, n), lambda _, i: x2[0, i])\n",
    "\n",
    "\n",
    "def s1_step(t, i):\n",
    "    \"\"\"Next s1 row: previous s1 and s2 rows plus the current x1 row.\"\"\"\n",
    "    return s1[t - 1, i] + s2[t - 1, i] + x1[t, i]\n",
    "\n",
    "\n",
    "def s2_step(t, i):\n",
    "    \"\"\"Next s2 row: current x2 row plus the previous s2 row.\"\"\"\n",
    "    return x2[t, i] + s2[t - 1, i]\n",
    "\n",
    "\n",
    "s1_update = te.compute((m, n), s1_step)\n",
    "s2_update = te.compute((m, n), s2_step)\n",
    "\n",
    "r0, r1 = te.scan([s1_init, s2_init], [s1_update, s2_update], [s1, s2])\n",
    "# The two results are outputs 0 and 1 of the scan.\n",
    "assert (r0.value_index, r1.value_index) == (0, 1)\n",
    "\n",
    "# The op behind the results survives a JSON round-trip as a ScanOp.\n",
    "json_str = tvm.ir.save_json(r0.op)\n",
    "zz = tvm.ir.load_json(json_str)\n",
    "assert isinstance(zz, te.ScanOp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## extern"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "A = te.placeholder((m,), name=\"A\")\n",
    "\n",
    "\n",
    "def extern_func(ins, outs):\n",
    "    \"\"\"Build the packed call to `myadd` on the first input and first output buffer.\"\"\"\n",
    "    src = ins[0]\n",
    "    # The callback is handed buffers, not tensors.\n",
    "    assert isinstance(src, tvm.tir.Buffer)\n",
    "    dst = outs[0]\n",
    "    return tvm.tir.call_packed(\"myadd\", src.data, dst.data, m)\n",
    "\n",
    "\n",
    "B = te.extern((m,), [A], extern_func)\n",
    "assert tuple(B.shape) == (m,)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## extern_multi_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "A = te.placeholder((m,), name=\"A\")\n",
    "B = te.compute((m,), lambda i: A[i] * 10)\n",
    "\n",
    "\n",
    "def extern_func(ins, outs):\n",
    "    \"\"\"Build the packed call to `myadd` on the first input and both output buffers.\"\"\"\n",
    "    src = ins[0]\n",
    "    # The callback is handed buffers, not tensors.\n",
    "    assert isinstance(src, tvm.tir.Buffer)\n",
    "    out_handles = (outs[0].data, outs[1].data)\n",
    "    return tvm.tir.call_packed(\"myadd\", src.data, *out_handles, m)\n",
    "\n",
    "\n",
    "# Two output shapes are requested, so two tensors come back.\n",
    "output_shapes = [A.shape, A.shape]\n",
    "res = te.extern(output_shapes, [A, B], extern_func)\n",
    "assert len(res) == 2\n",
    "assert res[1].value_index == 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tuple_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "A0 = te.placeholder((m, n), name=\"A0\")\n",
    "A1 = te.placeholder((m, n), name=\"A1\")\n",
    "\n",
    "\n",
    "def scaled_pair(i, j):\n",
    "    \"\"\"Return the two values computed per element: A0 * 2 and A1 * 3.\"\"\"\n",
    "    return A0[i, j] * 2, A1[i, j] * 3\n",
    "\n",
    "\n",
    "# Returning a tuple from the compute function yields one tensor per element.\n",
    "T0, T1 = te.compute((m, n), scaled_pair, name=\"T\")\n",
    "# Only the first output is passed on to create_prim_func.\n",
    "s = te.create_prim_func([A0, A1, T0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tuple_with_different_deps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "# from tvm.script import tir as T\n",
       "\n",
       "@T.prim_func\n",
       "def main(var_A1: T.handle, var_A2: T.handle, var_C: T.handle):\n",
       "    T.func_attr({\"tir.noalias\": True})\n",
       "    m, n = T.int32(is_size_var=True), T.int32(is_size_var=True)\n",
       "    A1 = T.match_buffer(var_A1, (m, n))\n",
       "    A2 = T.match_buffer(var_A2, (m, n))\n",
       "    C = T.match_buffer(var_C, (m, n))\n",
       "    # with T.block(\"root\"):\n",
       "    B_v0 = T.alloc_buffer((m, n))\n",
       "    B_v1 = T.alloc_buffer((m, n))\n",
       "    for i, j in T.grid(m, n):\n",
       "        with T.block(\"B_v0\"):\n",
       "            v_i, v_j = T.axis.remap(\"SS\", [i, j])\n",
       "            T.reads(A1[v_i, v_j])\n",
       "            T.writes(B_v0[v_i, v_j])\n",
       "            B_v0[v_i, v_j] = A1[v_i, v_j] * T.float32(2.0)\n",
       "        with T.block(\"B_v1\"):\n",
       "            v_i, v_j = T.axis.remap(\"SS\", [i, j])\n",
       "            T.reads(A2[v_i, v_j])\n",
       "            T.writes(B_v1[v_i, v_j])\n",
       "            B_v1[v_i, v_j] = A2[v_i, v_j] * T.float32(3.0)\n",
       "    for i, j in T.grid(m, n):\n",
       "        with T.block(\"C\"):\n",
       "            v_i, v_j = T.axis.remap(\"SS\", [i, j])\n",
       "            T.reads(B_v0[v_i, v_j])\n",
       "            T.writes(C[v_i, v_j])\n",
       "            C[v_i, v_j] = B_v0[v_i, v_j] + T.float32(4.0)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "m = te.size_var(\"m\")\n",
    "n = te.size_var(\"n\")\n",
    "# The Python variable names match the tensor names (\"A1\", \"A2\") that appear in\n",
    "# the generated TIR, so the source and the printed function read the same.\n",
    "A1 = te.placeholder((m, n), name=\"A1\")\n",
    "A2 = te.placeholder((m, n), name=\"A2\")\n",
    "B0, B1 = te.compute((m, n), lambda i, j: (A1[i, j] * 2, A2[i, j] * 3), name=\"B\")\n",
    "# C reads only B0, the first of the two outputs of B.\n",
    "C = te.compute((m, n), lambda i, j: B0[i, j] + 4, name=\"C\")\n",
    "\n",
    "te.create_prim_func([A1, A2, C])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tensor_inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = te.placeholder((1,), name=\"x\")\n",
    "\n",
    "\n",
    "def doubled(i):\n",
    "    \"\"\"Add element i of x to itself.\"\"\"\n",
    "    return x[i] + x[i]\n",
    "\n",
    "\n",
    "y = te.compute(x.shape, doubled)\n",
    "# x is read twice in the body but is listed once among the op's input tensors.\n",
    "assert tuple(y.op.input_tensors) == (x,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
